import scalevi.models.models_branched as models_branched 
import scalevi.models.models as models 
import scalevi.dataloader as dataloader
import scalevi.utils.utils as utils

def get_model_args(config_dict):
    """Collect the constructor keyword arguments for the configured model.

    The returned mapping always contains ``'data'``, obtained from
    ``dataloader.get_data(config_dict, nb=False)``. If ``models_branched``
    has an attribute named ``config_dict['model']``, the mapping also
    contains ``'N_chunk'``, taken from ``config_dict['N_leaves']``.

    Args:
        config_dict: configuration mapping. Reads ``'model'`` and, for
            branched models, ``'N_leaves'``. It is also passed as a whole
            to ``dataloader.get_data``.

    Returns:
        dict of keyword arguments for the model constructor.

    Raises:
        KeyError: if ``'model'`` is missing, or if ``'N_leaves'`` is missing
            for a branched model.
    """
    # Data is loaded before the model name is inspected, matching the
    # original evaluation order.
    model_args = {'data': dataloader.get_data(config_dict, nb=False)}
    # NOTE(review): membership is checked against models_branched only.
    # get_model_class also searches `models`, so a name defined in both
    # modules may receive N_chunk. Confirm against utils.get_attribute.
    is_branched = hasattr(models_branched, config_dict['model'])
    if is_branched:
        model_args['N_chunk'] = config_dict['N_leaves']
    return model_args

def get_model_class(config_dict):
    """Resolve the model class named by ``config_dict['model']``.

    The lookup is delegated to ``utils.get_attribute`` with the candidate
    modules ``models`` and ``models_branched``. The modules are presumably
    searched in that order; confirm in ``utils.get_attribute``.

    Args:
        config_dict: configuration mapping; reads ``'model'``.

    Returns:
        Whatever ``utils.get_attribute`` returns for that name.
    """
    candidate_modules = [models, models_branched]
    model_name = config_dict['model']
    return utils.get_attribute(candidate_modules, model_name)

def get_model(config_dict):
    """Instantiate the model described by ``config_dict``.

    The class is resolved with ``get_model_class``. It is then called with
    the keyword arguments built by ``get_model_args``.

    Args:
        config_dict: configuration mapping understood by both helpers.

    Returns:
        The constructed model instance.
    """
    # Resolve the class before building args, matching the original
    # evaluation order.
    model_cls = get_model_class(config_dict)
    model_kwargs = get_model_args(config_dict)
    return model_cls(**model_kwargs)